
#include "catch2/catch.hpp"
#include "utils/util_functions.h"

// Exercises CalcIoU and NMS from utils/util_functions.h against hand-checked
// reference data. Each SECTION runs as an independent pass of this test case.
TEST_CASE("Function") {
  SECTION("IOU") {
    // NOTE(review): the reference value equals 144 / 4150 ~= 0.034699 if the
    // boxes are read as {x1, y1, x2, y2} corners with no +1 pixel correction
    // (overlap 16 x 9; areas 20 x 20 and 66 x 59). Presumably that is
    // CalcIoU's convention -- confirm against its implementation.
    const float overlap_ratio = CalcIoU({10, 20, 30, 40}, {14, 31, 80, 90});

    // Approx with epsilon(0.001) allows a 0.1% relative deviation.
    REQUIRE(overlap_ratio == Approx(0.03469).epsilon(0.001));
  }

  SECTION("NMS") {
    // Rows appear to be {x1, y1, x2, y2, score}. They form two tightly
    // overlapping clusters (x1 around 154 and x1 around 200). Only the
    // best-scoring row of each cluster is expected to survive suppression.
    std::vector<std::vector<float>> candidates{
        {154, 88, 209, 120, 0.800510}, {154, 88, 210, 121, 0.864651},
        {200, 124, 253, 162, 0.934630}, {200, 124, 253, 163, 0.927808},
        {155, 88, 209, 121, 0.813166}, {154, 88, 209, 121, 0.861870},
        {200, 124, 253, 162, 0.928783}, {200, 123, 253, 163, 0.921736},
        {200, 124, 253, 162, 0.845653}, {200, 124, 253, 163, 0.825735},
        {153, 87, 210, 121, 0.808621}, {199, 123, 256, 162, 0.804178},
        {199, 124, 254, 163, 0.813115}, {199, 124, 254, 162, 0.830227},
        {154, 88, 209, 121, 0.848950}, {200, 124, 253, 164, 0.899599},
        {201, 124, 253, 162, 0.903123}};

    // Kept as double so the argument converts exactly as the literal 0.7 would.
    constexpr double kNmsThreshold = 0.7;
    const std::vector<std::vector<float>> survivors =
        NMS(candidates, kNmsThreshold);

    // Survivors are expected in descending score order: 0.934630, then 0.864651.
    const std::vector<float> expected_first{200, 124, 253, 162, 0.934630};
    const std::vector<float> expected_second{154, 88, 210, 121, 0.864651};

    // REQUIRE (not CHECK) so the indexing below never runs out of bounds.
    REQUIRE(survivors.size() == 2);
    CHECK(survivors[0] == expected_first);
    CHECK(survivors[1] == expected_second);
  }
}
